{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nCompile TFLite Models\n=====================\n**Author**: `Zhao Wu <https://github.com/FrozenGene>`_\n\nThis article is an introductory tutorial on deploying TFLite models with Relay.\n\nTo get started, the TFLite package needs to be installed as a prerequisite.\n\n.. code-block:: bash\n\n    # install tflite\n    pip install tflite==2.1.0 --user\n\n\nAlternatively, you can generate the TFLite package yourself. The steps are as follows:\n\n.. code-block:: bash\n\n    # Get the flatc compiler.\n    # Please refer to https://github.com/google/flatbuffers for details\n    # and make sure it is properly installed.\n    flatc --version\n\n    # Get the TFLite schema.\n    wget https://raw.githubusercontent.com/tensorflow/tensorflow/r1.13/tensorflow/lite/schema/schema.fbs\n\n    # Generate TFLite package.\n    flatc --python schema.fbs\n\n    # Add current folder (which contains generated tflite module) to PYTHONPATH.\n    export PYTHONPATH=${PYTHONPATH:+$PYTHONPATH:}$(pwd)\n\n\nNow check that the TFLite package is installed successfully by running ``python -c \"import tflite\"``.\n\nBelow you can find an example of how to compile a TFLite model using TVM.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Utils for downloading and extracting tar archives\n-------------------------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n\n\ndef extract(path):\n    \"\"\"Extract a tar archive into the directory that contains it.\n\n    Parameters\n    ----------\n    path : str\n        Path of the archive; its name must end in \"gz\" (which also covers \"tgz\").\n\n    Raises\n    ------\n    RuntimeError\n        If the name of ``path`` does not end in \"gz\".\n    \"\"\"\n    import tarfile\n\n    if not path.endswith(\"gz\"):\n        raise RuntimeError(\"Could not decompress the file: \" + path)\n\n    # The with-block closes the archive even when extraction raises.\n    # NOTE(review): extractall() trusts member paths; use on trusted archives only.\n    with tarfile.open(path) as tar:\n        tar.extractall(path=os.path.dirname(path))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Load pretrained TFLite model\n----------------------------\nLoad the MobileNet V1 TFLite model provided by Google.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from tvm.contrib.download import download_testdata\n\nmodel_url = \"http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz\"\n\n# Download model tar file and extract it to get mobilenet_v1_1.0_224.tflite\nmodel_path = download_testdata(model_url, \"mobilenet_v1_1.0_224.tgz\", module=[\"tf\", \"official\"])\nmodel_dir = os.path.dirname(model_path)\nextract(model_path)\n\n# Read mobilenet_v1_1.0_224.tflite into memory; the with-block closes the\n# file handle as soon as the bytes are read instead of leaving it open\ntflite_model_file = os.path.join(model_dir, \"mobilenet_v1_1.0_224.tflite\")\nwith open(tflite_model_file, \"rb\") as model_file:\n    tflite_model_buf = model_file.read()\n\n# Get TFLite model from buffer.\n# tflite.Model is tried as the class first; if that raises AttributeError,\n# it is imported as a module and the class tflite.Model.Model is used instead.\ntry:\n    import tflite\n\n    tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)\nexcept AttributeError:\n    import tflite.Model\n\n    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Load a test image\n-----------------\nA single cat dominates the examples!\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from PIL import Image\nfrom matplotlib import pyplot as plt\nimport numpy as np\n\nimage_url = \"https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true\"\nimage_path = download_testdata(image_url, \"cat.png\", module=\"data\")\n\n# Resize the picture to 224x224 and preview it\nresized_image = Image.open(image_path).resize((224, 224))\nplt.imshow(resized_image)\nplt.show()\n\n# Convert to float32 and prepend a batch axis so the layout is NHWC\nimage_data = np.expand_dims(np.asarray(resized_image).astype(\"float32\"), axis=0)\n\n# Apply x * 2 / 255 - 1 to the first three channels (maps 0..255 onto -1..1),\n# following the preprocessing described here:\n# https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243\nimage_data[:, :, :, :3] = 2.0 / 255.0 * image_data[:, :, :, :3] - 1\nprint(\"input\", image_data.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Compile the model with Relay\n----------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from tvm import relay, transform\n\n# Name, shape and dtype of the TFLite input tensor\ninput_tensor = \"input\"\ninput_shape = (1, 224, 224, 3)\ninput_dtype = \"float32\"\n\n# Convert the parsed TFLite model into a Relay module and its parameters\nshape_dict = {input_tensor: input_shape}\ndtype_dict = {input_tensor: input_dtype}\nmod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict)\n\n# Build the module for the \"llvm\" target at optimization level 3\ntarget = \"llvm\"\nwith transform.PassContext(opt_level=3):\n    lib = relay.build(mod, target, params=params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Execute on TVM\n--------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\nfrom tvm import te\nfrom tvm.contrib import graph_executor as runtime\n\n# Instantiate the compiled graph on the CPU device\ndev = tvm.cpu()\nmodule = runtime.GraphModule(lib[\"default\"](dev))\n\n# Bind the preprocessed image to the input tensor, then run inference\ninput_array = tvm.nd.array(image_data)\nmodule.set_input(input_tensor, input_array)\nmodule.run()\n\n# Fetch output 0 as a NumPy array\ntvm_output = module.get_output(0).numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Display results\n---------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Load label file\nlabel_file_url = \"\".join(\n    [\n        \"https://raw.githubusercontent.com/\",\n        \"tensorflow/tensorflow/master/tensorflow/lite/java/demo/\",\n        \"app/src/main/assets/\",\n        \"labels_mobilenet_quant_v1_224.txt\",\n    ]\n)\nlabel_file = \"labels_mobilenet_quant_v1_224.txt\"\nlabel_path = download_testdata(label_file_url, label_file, module=\"data\")\n\n# List of 1001 classes, one per line. The trailing newline of each line is\n# removed so the printed label is not followed by a stray blank line.\nwith open(label_path) as f:\n    labels = [line.rstrip(\"\\n\") for line in f]\n\n# Convert result to 1D data\npredictions = np.squeeze(tvm_output)\n\n# Get top 1 prediction\nprediction = np.argmax(predictions)\n\n# Convert id to class name and show the result\nprint(\"The image prediction result is: id \" + str(prediction) + \" name: \" + labels[prediction])"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}